<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
  <meta charset="utf-8">
  <meta http-equiv="X-UA-Compatible" content="IE=edge">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  
  <link rel="shortcut icon" href="../../img/favicon.ico">
  <title>Addition RNN - Keras 中文文档</title>
  <link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>

  <link rel="stylesheet" href="../../css/theme.css" type="text/css" />
  <link rel="stylesheet" href="../../css/theme_extra.css" type="text/css" />
  <link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
  
  <script>
    // Current page data
    var mkdocs_page_name = "Addition RNN";
    var mkdocs_page_input_path = "examples/addition_rnn.md";
    var mkdocs_page_url = "/zh/examples/addition_rnn/";
  </script>
  
  <script src="../../js/jquery-2.1.1.min.js" defer></script>
  <script src="../../js/modernizr-2.8.3.min.js" defer></script>
  <script src="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script>
  <script>hljs.initHighlightingOnLoad();</script> 
  
  <script>
      (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
      (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
      m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
      })(window,document,'script','https://www.google-analytics.com/analytics.js','ga');

      ga('create', 'UA-61785484-1', 'keras.io');
      ga('send', 'pageview');
  </script>
  
</head>

<body class="wy-body-for-nav" role="document">

  <div class="wy-grid-for-nav">

    
    <nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
      <div class="wy-side-nav-search">
        <a href="../.." class="icon icon-home"> Keras 中文文档</a>
        <div role="search">
  <form id ="rtd-search-form" class="wy-form" action="../../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" title="Type search term here" />
  </form>
</div>
      </div>

      <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
	<ul class="current">
	  
          
            <li class="toctree-l1">
		
    <a class="" href="../..">主页</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../why-use-keras/">为什么选择 Keras?</a>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">快速开始</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../getting-started/sequential-model-guide/">Sequential 顺序模型指引</a>
                </li>
                <li class="">
                    
    <a class="" href="../../getting-started/functional-api-guide/">函数式 API 指引</a>
                </li>
                <li class="">
                    
    <a class="" href="../../getting-started/faq/">FAQ 常见问题解答</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">模型</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../models/about-keras-models/">关于 Keras 模型</a>
                </li>
                <li class="">
                    
    <a class="" href="../../models/sequential/">Sequential 顺序模型 API</a>
                </li>
                <li class="">
                    
    <a class="" href="../../models/model/">函数式 API</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">Layers</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../layers/about-keras-layers/">关于 Keras 网络层</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/core/">核心网络层</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/convolutional/">卷积层 Convolutional</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/pooling/">池化层 Pooling</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/local/">局部连接层 Locally-connected</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/recurrent/">循环层 Recurrent</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/embeddings/">嵌入层 Embedding</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/merge/">融合层 Merge</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/advanced-activations/">高级激活层 Advanced Activations</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/normalization/">标准化层 Normalization</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/noise/">噪声层 Noise</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/wrappers/">层封装器 wrappers</a>
                </li>
                <li class="">
                    
    <a class="" href="../../layers/writing-your-own-keras-layers/">编写你自己的层</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">数据预处理</span>
    <ul class="subnav">
                <li class="">
                    
    <a class="" href="../../preprocessing/sequence/">序列预处理</a>
                </li>
                <li class="">
                    
    <a class="" href="../../preprocessing/text/">文本预处理</a>
                </li>
                <li class="">
                    
    <a class="" href="../../preprocessing/image/">图像预处理</a>
                </li>
    </ul>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../losses/">损失函数 Losses</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../metrics/">评估标准 Metrics</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../optimizers/">优化器 Optimizers</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../activations/">激活函数 Activations</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../callbacks/">回调函数 Callbacks</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../datasets/">常用数据集 Datasets</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../applications/">应用 Applications</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../backend/">后端 Backend</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../initializers/">初始化 Initializers</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../regularizers/">正则化 Regularizers</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../constraints/">约束 Constraints</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../visualization/">可视化 Visualization</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../scikit-learn-api/">Scikit-learn API</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../utils/">工具</a>
	    </li>
          
            <li class="toctree-l1">
		
    <a class="" href="../../contributing/">贡献</a>
	    </li>
          
            <li class="toctree-l1">
		
    <span class="caption-text">经典样例</span>
    <ul class="subnav">
                <li class=" current">
                    
    <a class="current" href="./">Addition RNN</a>
    <ul class="subnav">
            
    <li class="toctree-l3"><a href="#_1">实现一个用来执行加法的序列到序列学习模型</a></li>
    

    </ul>
                </li>
                <li class="">
                    
    <a class="" href="../babi_rnn/">Baby RNN</a>
                </li>
                <li class="">
                    
    <a class="" href="../babi_memnn/">Baby MemNN</a>
                </li>
                <li class="">
                    
    <a class="" href="../cifar10_cnn/">CIFAR-10 CNN</a>
                </li>
                <li class="">
                    
    <a class="" href="../cifar10_cnn_capsule/">CIFAR-10 CNN-Capsule</a>
                </li>
                <li class="">
                    
    <a class="" href="../cifar10_cnn_tfaugment2d/">CIFAR-10 CNN with augmentation (TF)</a>
                </li>
                <li class="">
                    
    <a class="" href="../cifar10_resnet/">CIFAR-10 ResNet</a>
                </li>
                <li class="">
                    
    <a class="" href="../conv_filter_visualization/">Convolution filter visualization</a>
                </li>
                <li class="">
                    
    <a class="" href="../image_ocr/">Image OCR</a>
                </li>
                <li class="">
                    
    <a class="" href="../imdb_bidirectional_lstm/">Bidirectional LSTM</a>
                </li>
    </ul>
	    </li>
          
        </ul>
      </div>
      &nbsp;
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" role="navigation" aria-label="top navigation">
        <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
        <a href="../..">Keras 中文文档</a>
      </nav>

      
      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="breadcrumbs navigation">
  <ul class="wy-breadcrumbs">
    <li><a href="../..">Docs</a> &raquo;</li>
    
      
        
          <li>经典样例 &raquo;</li>
        
      
    
    <li>Addition RNN</li>
    <li class="wy-breadcrumbs-aside">
      
        <a href="https://github.com/keras-team/keras-docs-zh/edit/master/docs/examples/addition_rnn.md"
          class="icon icon-github"> Edit on GitHub</a>
      
    </li>
  </ul>
  <hr/>
</div>
          <div role="main">
            <div class="section">
              
                <h1 id="_1">实现一个用来执行加法的序列到序列学习模型</h1>
<p>输入: "535+61"</p>
<p>输出: "596"</p>
<p>使用重复的标记字符（空格）处理填充。</p>
<p>输入可以选择性地反转，它被认为可以提高许多任务的性能，例如：
<a href="http://arxiv.org/abs/1410.4615">Learning to Execute</a>
以及
<a href="http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf">Sequence to Sequence Learning with Neural Networks</a>。</p>
<p>从理论上讲，它引入了源和目标之间的短期依赖关系。</p>
<p>两个反转的数字 + 一个 LSTM 层（128个隐藏单元），在 55 个 epochs 后，5k 的训练样本取得了 99% 的训练/测试准确率。</p>
<p>三个反转的数字 + 一个 LSTM 层（128个隐藏单元），在 100 个 epochs 后，50k 的训练样本取得了 99% 的训练/测试准确率。</p>
<p>四个反转的数字 + 一个 LSTM 层（128个隐藏单元），在 20 个 epochs 后，400k 的训练样本取得了 99% 的训练/测试准确率。</p>
<p>五个反转的数字 + 一个 LSTM 层（128个隐藏单元），在 30 个 epochs 后，550k 的训练样本取得了 99% 的训练/测试准确率。</p>
<pre><code class="python">from __future__ import print_function
from keras.models import Sequential
from keras import layers
import numpy as np
from six.moves import range


class CharacterTable(object):
    &quot;&quot;&quot;给定一组字符：
    + 将它们编码为 one-hot 整数表示
    + 将 one-hot 或整数表示解码为字符输出
    + 将一个概率向量解码为字符输出
    &quot;&quot;&quot;
    def __init__(self, chars):
        &quot;&quot;&quot;初始化字符表。

        # 参数：
            chars: 可以出现在输入中的字符。
        &quot;&quot;&quot;
        self.chars = sorted(set(chars))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))

    def encode(self, C, num_rows):
        &quot;&quot;&quot;给定字符串 C 的 one-hot 编码。

        # 参数
            C: 需要编码的字符串。
            num_rows: 返回的 one-hot 编码的行数。
                      这用来保证每个数据的行数相同。
        &quot;&quot;&quot;
        x = np.zeros((num_rows, len(self.chars)))
        for i, c in enumerate(C):
            x[i, self.char_indices[c]] = 1
        return x

    def decode(self, x, calc_argmax=True):
        &quot;&quot;&quot;将给定的向量或 2D array 解码为它们的字符输出。

        # 参数
            x: 一个向量或 2D 概率数组或 one-hot 表示，
               或 字符索引的向量（如果 `calc_argmax=False`）。
            calc_argmax: 是否根据最大概率来找到字符，默认为 `True`。
        &quot;&quot;&quot;
        if calc_argmax:
            x = x.argmax(axis=-1)
        return ''.join(self.indices_char[x] for x in x)


class colors:
    ok = '\033[92m'
    fail = '\033[91m'
    close = '\033[0m'

# 模型和数据的参数
TRAINING_SIZE = 50000
DIGITS = 3
REVERSE = True

# 输入的最大长度是 'int+int' (例如, '345+678'). int 的最大长度为 DIGITS。
MAXLEN = DIGITS + 1 + DIGITS

# 所有的数字，加上符号，以及用于填充的空格。
chars = '0123456789+ '
ctable = CharacterTable(chars)

questions = []
expected = []
seen = set()
print('Generating data...')
while len(questions) &lt; TRAINING_SIZE:
    f = lambda: int(''.join(np.random.choice(list('0123456789'))
                    for i in range(np.random.randint(1, DIGITS + 1))))
    a, b = f(), f()
    # 跳过任何已有的加法问题
    # 同事跳过任何 x+Y == Y+x 的情况(即排序)。
    key = tuple(sorted((a, b)))
    if key in seen:
        continue
    seen.add(key)
    # 利用空格填充，是的长度始终为 MAXLEN。
    q = '{}+{}'.format(a, b)
    query = q + ' ' * (MAXLEN - len(q))
    ans = str(a + b)
    # 答案可能的最长长度为 DIGITS + 1。
    ans += ' ' * (DIGITS + 1 - len(ans))
    if REVERSE:
        # 反转查询，例如，'12+345  ' 变成 '  543+21'. 
        # (注意用于填充的空格)
        query = query[::-1]
    questions.append(query)
    expected.append(ans)
print('Total addition questions:', len(questions))

print('Vectorization...')
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=np.bool)
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=np.bool)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, DIGITS + 1)

# 混洗 (x, y)，因为 x 的后半段几乎都是比较大的数字。
indices = np.arange(len(y))
np.random.shuffle(indices)
x = x[indices]
y = y[indices]

# 显式地分离出 10% 的训练数据作为验证集。
split_at = len(x) - len(x) // 10
(x_train, x_val) = x[:split_at], x[split_at:]
(y_train, y_val) = y[:split_at], y[split_at:]

print('Training Data:')
print(x_train.shape)
print(y_train.shape)

print('Validation Data:')
print(x_val.shape)
print(y_val.shape)

# 可以尝试更改为 GRU, 或 SimpleRNN。
RNN = layers.LSTM
HIDDEN_SIZE = 128
BATCH_SIZE = 128
LAYERS = 1

print('Build model...')
model = Sequential()
# 利用 RNN 将输入序列「编码」为一个 HIDDEN_SIZE 长度的输出向量。
# 注意：在输入序列具有可变长度的情况下,
# 使用 input_shape=(None, num_feature).
model.add(RNN(HIDDEN_SIZE, input_shape=(MAXLEN, len(chars))))
# 作为解码器 RNN 的输入，为每个时间步重复地提供 RNN 的最后输出。
# 重复 'DIGITS + 1' 次，因为它是最大输出长度。
# 例如，当 DIGITS=3, 最大输出为 999+999=1998。
model.add(layers.RepeatVector(DIGITS + 1))
# 解码器 RNN 可以是多个堆叠的层，或一个单独的层。
for _ in range(LAYERS):
    # 通过设置 return_sequences 为 True, 将不仅返回最后一个输出，而是返回目前的所有输出，形式为(num_samples, timesteps, output_dim)。
    # 这是必须的，因为后面的 TimeDistributed 需要第一个维度是时间步。
    model.add(RNN(HIDDEN_SIZE, return_sequences=True))

# 将全连接层应用于输入的每个时间片。
# 对于输出序列的每一步，决定应选哪个字符。
model.add(layers.TimeDistributed(layers.Dense(len(chars), activation='softmax')))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
model.summary()

# 训练模型，并在每一代显示验证数据的预测。
for iteration in range(1, 200):
    print()
    print('-' * 50)
    print('Iteration', iteration)
    model.fit(x_train, y_train,
              batch_size=BATCH_SIZE,
              epochs=1,
              validation_data=(x_val, y_val))
    # 从随机验证集中选择 10 个样本，以便我们可以看到错误。
    for i in range(10):
        ind = np.random.randint(0, len(x_val))
        rowx, rowy = x_val[np.array([ind])], y_val[np.array([ind])]
        preds = model.predict_classes(rowx, verbose=0)
        q = ctable.decode(rowx[0])
        correct = ctable.decode(rowy[0])
        guess = ctable.decode(preds[0], calc_argmax=False)
        print('Q', q[::-1] if REVERSE else q, end=' ')
        print('T', correct, end=' ')
        if correct == guess:
            print(colors.ok + '☑' + colors.close, end=' ')
        else:
            print(colors.fail + '☒' + colors.close, end=' ')
        print(guess)
</code></pre>
              
            </div>
          </div>
          <footer>
  
    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
      
        <a href="../babi_rnn/" class="btn btn-neutral float-right" title="Baby RNN">Next <span class="icon icon-circle-arrow-right"></span></a>
      
      
        <a href="../../contributing/" class="btn btn-neutral" title="贡献"><span class="icon icon-circle-arrow-left"></span> Previous</a>
      
    </div>
  

  <hr/>

  <div role="contentinfo">
    <!-- Copyright etc -->
    
  </div>

  Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
      
        </div>
      </div>

    </section>

  </div>

  <div class="rst-versions" role="note" style="cursor: pointer">
    <span class="rst-current-version" data-toggle="rst-current-version">
      
          <a href="https://github.com/keras-team/keras-docs-zh/" class="fa fa-github" style="float: left; color: #fcfcfc"> GitHub</a>
      
      
        <span><a href="../../contributing/" style="color: #fcfcfc;">&laquo; Previous</a></span>
      
      
        <span style="margin-left: 15px"><a href="../babi_rnn/" style="color: #fcfcfc">Next &raquo;</a></span>
      
    </span>
</div>
    <script>var base_url = '../..';</script>
    <script src="../../js/theme.js" defer></script>
      <script src="../../search/main.js" defer></script>

</body>
</html>
